{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Satellite imagery.ipynb",
      "provenance": [],
      "collapsed_sections": [
        "iVeYZpy3fF0N"
      ]
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iVeYZpy3fF0N",
        "colab_type": "text"
      },
      "source": [
        "\n",
        "##### Copyright 2020 Google LLC.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0eXL156ae-iT",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "@title License text\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-8-AUUdfGano",
        "colab_type": "text"
      },
      "source": [
        "### Initialize Tensor Flow and GPU devices, import modules"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "PBo1-ZMlCkun",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 51
        },
        "outputId": "9ad4077a-7ae6-4a00-81fb-83833d882561"
      },
      "source": [
        "import tensorflow as tf\n",
        "print(\"Tensorflow version \" + tf.__version__)\n",
        "device_name = tf.test.gpu_device_name()\n",
        "if device_name != '/device:GPU:0':\n",
        "  raise SystemError('GPU device not found')\n",
        "print(f'Found GPU at: {device_name}')"
      ],
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Tensorflow version 2.2.0\n",
            "Found GPU at: /device:GPU:0\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gRuhCmVOtfhJ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!pip install opencv-python\n",
        "\n",
        "from concurrent import futures\n",
        "\n",
        "import io\n",
        "import os\n",
        "import re\n",
        "import tarfile\n",
        "\n",
        "import cv2\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import matplotlib.colors as plt_colors\n",
        "import pandas as pd\n",
        "import tensorflow_datasets as tfds\n",
        "\n",
        "from tensorflow import keras\n",
        "from tensorflow.keras import layers\n",
        "\n",
        "from typing import Callable, Dict, Optional, Tuple\n",
        "Features = Dict[str, tf.Tensor]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "s7xDGunUGyUu",
        "colab_type": "text"
      },
      "source": [
        "### Download raw images and annotation locally\n",
        "Khartoum city images from [Spacenet Buildings v2](https://https://spacenetchallenge.github.io/datasets/spacenetBuildings-V2summary.html) dataset are used. The task is to segment the instances of buidlings in the images.\n",
        "\n",
        "The SpaceNet Dataset by [SpaceNet Partners](https://https://spacenetchallenge.github.io/) is licensed under a [Creative Commons Attribution-ShareAlike 4.0 International License](https://http://creativecommons.org/licenses/by-sa/4.0/)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DXT2N95vDXQh",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Download a small part of the public Spacenet-v2 dataset.\n",
        "# The dataset structure is documented at https://spacenet.ai/khartoum/\n",
        "\n",
        "# NOTE: This cell takes a long time to execute. If colab is disconnected from\n",
        "# a runtime, all data is lost. Consider storing the unpacked gzip archive in\n",
        "# some external directory you can access (e.g. Google Cloud Storage bucket).\n",
        "\n",
        "DATASET_TAR = \"/tmp/AOI_5_Khartoum_train.tar.gz\"\n",
        "\n",
        "# Using tf.io.gfile allows to access AWS and GCS buckets directly from a colab.\n",
        "tf.io.gfile.copy(\"s3://spacenet-dataset/spacenet/SN2_buildings/tarballs/SN2_buildings_train_AOI_5_Khartoum.tar.gz\",\n",
        "                 DATASET_TAR)\n",
        "\n",
        "tf.io.gfile.mkdir(\"/tmp/spacenet\")\n",
        "with tarfile.open(DATASET_TAR) as tar_f:\n",
        "  tar_f.extractall(\"/tmp/spacenet\")\n",
        "\n",
        "tf.io.gfile.listdir(\"/tmp/spacenet/AOI_5_Khartoum_Train\")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0Z63d0UyHwUS",
        "colab_type": "text"
      },
      "source": [
        "### Create a TensorFlow Datasets Builder\n",
        "It automatcially converts raw data into TF-Records and gives easy access throgh `tf.data.Dataset` API.\n",
        "\n",
        "See more at:\n",
        "\n",
        "https://www.tensorflow.org/api_docs/python/tf/data/Dataset\n",
        "\n",
        "https://www.tensorflow.org/datasets\n",
        "\n",
        "https://www.tensorflow.org/datasets/api_docs/python/tfds/core/GeneratorBasedBuilder"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BQeCW8Sr3FjP",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "_DESCRIPTION = \"Spacenet (Khartoum only)\"\n",
        "# The directory were the raw data lives.\n",
        "_ROOT_DIR = \"/tmp/spacenet/AOI_5_Khartoum_Train\"\n",
        "\n",
        "# Min/Max RGB value ranges over data from Khartoum.\n",
        "# Needed for Spacenet dataset to convert pixel values into [0, 255] range.\n",
        "# This can be pre-calculated in advance given access to all images or might not\n",
        "# be needed for your dataset at all.\n",
        "_GLOBAL_MIN = np.array([1.0, 1.0, 23.0])\n",
        "_GLOBAL_MAX = np.array([1933.0, 2047.0, 1610.0])\n",
        "\n",
        "IMAGE_HEIGHT, IMAGE_WIDTH = 650, 650\n",
        "\n",
        "class SpacenetConfig(tfds.core.BuilderConfig):\n",
        "  \"\"\"BuilderConfig for spacenet.\"\"\"\n",
        "\n",
        "  def __init__(self, **kwargs):\n",
        "    \"\"\"Constructs a SpacenetConfig.\n",
        "\n",
        "    Args:\n",
        "      **kwargs: keyword arguments forwarded to super.\n",
        "    \"\"\"\n",
        "    # Version history:\n",
        "    super().__init__(version=tfds.core.Version(\"0.0.1\"), **kwargs)\n",
        "    self.train_path = _ROOT_DIR\n",
        "    self.min_val = _GLOBAL_MIN\n",
        "    self.max_val = _GLOBAL_MAX\n",
        "\n",
        "\n",
        "class Spacenet(tfds.core.GeneratorBasedBuilder):\n",
        "  \"\"\"Spacenet remote sensing dataset (Khartoum only).\"\"\"\n",
        "\n",
        "  BUILDER_CONFIGS = [\n",
        "      SpacenetConfig(name=\"Spacenet-Khartoum\",\n",
        "                     description=_DESCRIPTION)\n",
        "  ]\n",
        "\n",
        "  def __init__(self, data_dir: Optional[str] = None, **kwargs):\n",
        "    # NOTE: use your GCS bucket path here to persist TFRecords across multiple\n",
        "    # runs.\n",
        "    data_dir = data_dir or \"/tmp/spacenet/tensorflow_datasets\"\n",
        "    super().__init__(data_dir=data_dir, **kwargs)\n",
        "\n",
        "  def _info(self) -> tfds.core.DatasetInfo:\n",
        "    return tfds.core.DatasetInfo(\n",
        "        builder=self,\n",
        "        description=_DESCRIPTION,\n",
        "        features=tfds.features.FeaturesDict({\n",
        "            \"image\":\n",
        "                tfds.features.Image(\n",
        "                    shape=[IMAGE_HEIGHT, IMAGE_WIDTH, 3],\n",
        "                    encoding_format=\"jpeg\"),\n",
        "            \"segmentation_mask\":\n",
        "                tfds.features.Image(\n",
        "                    shape=[IMAGE_HEIGHT, IMAGE_WIDTH, 1],\n",
        "                    encoding_format=\"png\"),\n",
        "        }))\n",
        "\n",
        "  def _split_generators(self, dl_manager):\n",
        "    \"\"\"Returns SplitGenerators.\"\"\"\n",
        "    train_path = self.builder_config.train_path\n",
        "    return [\n",
        "        tfds.core.SplitGenerator(\n",
        "            name=tfds.Split.TRAIN,\n",
        "            gen_kwargs={\"root_path\": train_path},\n",
        "        ),\n",
        "    ]\n",
        "\n",
        "  def _generate_examples(self, root_path: str):\n",
        "    \"\"\"Yields examples from raw data.\"\"\"\n",
        "    max_per_channel = self.builder_config.max_val\n",
        "    min_per_channel = self.builder_config.min_val\n",
        "    path = os.path.join(root_path, \"RGB-PanSharpen\")\n",
        "    buildings_path = os.path.join(root_path, \"summaryData\")\n",
        "    # Reading polygons coordinates and label them with respect to the img number\n",
        "    csv_files = tf.io.gfile.glob(buildings_path + \"/*.csv\")\n",
        "    with tf.io.gfile.GFile(csv_files[0], \"r\") as fid:\n",
        "      df = pd.read_csv(fid)\n",
        "    df[\"image\"] = [x.split(\"_img\")[-1] for x in df.ImageId]\n",
        "    files = tf.io.gfile.glob(path + \"/*.tif\")\n",
        "    for filename in files:\n",
        "      # Extract the image ID XXX from \"RGB-PanSharpen_AOI_5_Khartoum_imgXXX.tif\"\n",
        "      buildings_filename = filename.split(\"_\")[-1].split(\".\")[0][3:]\n",
        "      yield filename, {\n",
        "          \"image\": _load_tif(filename, max_per_channel, min_per_channel),\n",
        "          \"segmentation_mask\": _load_mask(df, buildings_filename),\n",
        "      }\n",
        "\n",
        "\n",
        "def get_poly_coordinate(poly: str) -> np.ndarray:\n",
        "  \"\"\"Returns polygons coordinates as numpy array.\"\"\"\n",
        "  return np.array([\n",
        "      pp.split(\" \") for pp in re.findall(r\"[0-9.\\-]+ [0-9.\\-]+ [0-9.\\-]+\", poly)\n",
        "  ],\n",
        "                  dtype=np.float32)\n",
        "\n",
        "\n",
        "def _load_mask(df: pd.core.series.Series,\n",
        "               buildings_filename: str) -> np.ndarray:\n",
        "  \"\"\"Returns a loaded segmentation mask image.\"\"\"\n",
        "  mask = np.zeros((IMAGE_HEIGHT, IMAGE_WIDTH, 1), dtype=np.uint8)\n",
        "  buildings = df[df.image == buildings_filename]\n",
        "  for _, building in buildings.iterrows():\n",
        "    poly_coord = get_poly_coordinate(building.PolygonWKT_Pix)\n",
        "    if poly_coord.size > 0:\n",
        "      # Subindex polygon coordinate from [x, y, 0] to [x, y]\n",
        "      poly_coord = poly_coord[:, :2]\n",
        "      cv2.fillPoly(mask, [np.array(poly_coord, dtype=np.int32)], 1)\n",
        "  return mask.astype(np.uint8)\n",
        "\n",
        "\n",
        "def _load_tif(filename: str,\n",
        "              max_per_channel: np.ndarray,\n",
        "              min_per_channel: np.ndarray) -> np.ndarray:\n",
        "  \"\"\"Loads TIF file and returns as an image array in [0, 1].\"\"\"\n",
        "  with tf.io.gfile.GFile(filename, \"rb\") as fid:\n",
        "    img = tfds.core.lazy_imports.skimage.external.tifffile.imread(\n",
        "        io.BytesIO(fid.read())).astype(np.float32)\n",
        "  img = (img - min_per_channel) / (max_per_channel - min_per_channel) * 255\n",
        "  img = np.clip(img, 0, 255).astype(np.uint8)\n",
        "  return img\n",
        "\n",
        "# Convert raw data into TFRecord form and prepare for access.\n",
        "tfds_builder = Spacenet()\n",
        "tfds_builder.download_and_prepare()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZgmL-hMLJab6",
        "colab_type": "text"
      },
      "source": [
        "### Create an input pipeline\n",
        "A `create_dataset` function that batches, shuffles and preprocesses the dataset according to given parameters."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "T-KipKU8iBhK",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "AUTOTUNE = tf.data.experimental.AUTOTUNE\n",
        "\n",
        "def create_dataset(dataset_builder,\n",
        "                   split: str,\n",
        "                   preprocess_fn: Callable[[Features], Features],\n",
        "                   batch_size: int,\n",
        "                   num_epochs: Optional[int] = None,\n",
        "                   shuffle: bool = False,\n",
        "                   shuffle_buffer_size: int = 1000) -> tf.data.Dataset:\n",
        "  \"\"\"Returns a dataset to be used with TensorFlow2.\n",
        "\n",
        "  Args:\n",
        "    dataset_builder: `tfds.DatasetBuilder` object.\n",
        "    split: Name of the split to use. One of {'train', 'validation', 'test'}.\n",
        "    preprocess_fn: Callable for preprocessing.\n",
        "    batch_size: The batch size to use.\n",
        "    num_epochs: Number of epochs. See `tf.data.Dataset.repeat()`.\n",
        "    shuffle: Whether to shuffle examples in memory.\n",
        "    shuffle_buffer_size: Number of examples in the shuffle buffer.\n",
        "\n",
        "  Returns:\n",
        "    A `tf.data.Dataset` with the processed and batched features.\n",
        "  \"\"\"\n",
        "\n",
        "  read_config = tfds.ReadConfig(options=tf.data.Options())\n",
        "\n",
        "  ds = dataset_builder.as_dataset(\n",
        "      read_config=read_config,\n",
        "      split=split,\n",
        "      shuffle_files=shuffle)\n",
        "\n",
        "  ds = ds.repeat(num_epochs)\n",
        "  if shuffle:\n",
        "    ds = ds.shuffle(shuffle_buffer_size)\n",
        "  ds = ds.map(preprocess_fn, num_parallel_calls=AUTOTUNE)\n",
        "  ds = ds.batch(batch_size, drop_remainder=True)\n",
        "\n",
        "  return ds.prefetch(AUTOTUNE)\n"
      ],
      "execution_count": 10,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "K43zawpmJrG0",
        "colab_type": "text"
      },
      "source": [
        "### Define training, test and validation splits\n",
        "\n",
        "For simplicity the training data is randomly split into 3 parts of\n",
        "size 70%, 20% and 10% respectively. You would probably need a more\n",
        "complex splitting for the real data."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5IBr3Xi_SVYB",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "TRAIN_SPLIT=\"train[:70%]\"\n",
        "VAL_SPLIT=\"train[70%:90%]\"\n",
        "TEST_SPLIT=\"train[90%:]\"\n",
        "\n"
      ],
      "execution_count": 11,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "seZzAdx6Z5wQ",
        "colab_type": "text"
      },
      "source": [
        "### Take a look at the dataset\n",
        "Are there any problems? One might notice that there are shifted and merged instances."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Mp-Cb_dPllD6",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "BATCH_SIZE = 16\n",
        "ds = create_dataset(Spacenet(),\n",
        "                    split=TRAIN_SPLIT,\n",
        "                    shuffle=False,\n",
        "                    preprocess_fn = lambda x: x,\n",
        "                    batch_size = BATCH_SIZE)\n",
        "\n",
        "for batch in ds.take(1):\n",
        "  fig, axs = plt.subplots(nrows=BATCH_SIZE, ncols=2, figsize=(16, 8*BATCH_SIZE))\n",
        "  for i in range(BATCH_SIZE):\n",
        "    axs[i, 0].imshow(batch[\"image\"][i])\n",
        "    axs[i, 1].imshow(batch[\"image\"][i])\n",
        "    axs[i, 1].imshow(tf.squeeze(batch[\"segmentation_mask\"][i]), cmap='gray', alpha=0.3)\n",
        "\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xmoCX0ngaesp",
        "colab_type": "text"
      },
      "source": [
        "### Define preprocessing\n",
        "We're going to use 3 easy preprocessing techniques:\n",
        "* Scaling pixels to [0, 1] range.\n",
        "* Resizing an image to a fixed size of (448, 448).\n",
        "* Randomly adjusting the brightness of the image (as satellite imagery might be taken with different illumination around the world)\n",
        "\n",
        "We're going to skip the brightness adjustment for preprocessing\n",
        "validation and test data, but keep scaling and resizing.\n",
        "\n",
        "The preprocessing is done with a function that takes and example\n",
        "emitted by our input `tf.data.Dataset` and returns the same example\n",
        "preprocessed and in the Keras expected format (see https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit).\n",
        "\n",
        "Consider snapshotting API to save the preprocessed dataset on disk if preprocessing is the perofrmance bottleneck:\n",
        "`https://www.tensorflow.org/api_docs/python/tf/data/experimental/snapshot`"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "PhnRkMm8j5oW",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def preprocess_fn(features: Dict[str, tf.Tensor], is_training: bool) -> Tuple[tf.Tensor, tf.Tensor]:\n",
        "  \"\"\"Runs preprocessing and converts examples into a Keras compatible format.\"\"\"\n",
        "  image = features[\"image\"]\n",
        "  mask = features[\"segmentation_mask\"]\n",
        "  # Rescale the image to [0..1]\n",
        "  image = tf.cast(image, tf.float32) / 255.0\n",
        "  # Resize the image and mask to (448, 448).\n",
        "  # Round resize mask values to nearest integer.\n",
        "  image = tf.image.resize(image, (448, 448))\n",
        "  mask = tf.cast(tf.image.resize(mask, (448, 448)), tf.int32)\n",
        "  # If training, apply random brightness change.\n",
        "  if is_training:\n",
        "    image = tf.image.random_brightness(image, max_delta=0.2)\n",
        "  return image, mask\n",
        "\n",
        "\n",
        "train_preprocess_fn = functools.partial(preprocess_fn, is_training=True)\n",
        "validation_preprocess_fn = functools.partial(preprocess_fn, is_training=False)\n",
        "test_preprocess_fn = functools.partial(preprocess_fn, is_training=False)\n"
      ],
      "execution_count": 38,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d4nslLT4b1iU",
        "colab_type": "text"
      },
      "source": [
        "### Now taking a look at the preprocessed dataset\n",
        "\n",
        "This step is sanity checking that our preprocessing does what we expect.\n",
        "\n",
        "\n",
        "E.g. note the brightness adjustments."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zysXrq0l4YnY",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "train_ds = create_dataset(Spacenet(),\n",
        "                          split=TRAIN_SPLIT,\n",
        "                          shuffle=True,\n",
        "                          preprocess_fn=train_preprocess_fn,\n",
        "                          batch_size=BATCH_SIZE)\n",
        "for batch in train_ds.take(1):\n",
        "  fig, axs = plt.subplots(nrows=BATCH_SIZE, ncols=2, figsize=(16, 8*BATCH_SIZE))\n",
        "  for i in range(BATCH_SIZE):\n",
        "    axs[i, 0].imshow(batch[0][i])\n",
        "    axs[i, 1].imshow(tf.squeeze(batch[0][i]))\n",
        "    axs[i, 1].imshow(tf.squeeze(batch[1][i]), cmap='gray', alpha=0.3)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hmXgxj9H6AzQ",
        "colab_type": "text"
      },
      "source": [
        "### Define a convolutional model.\n",
        "\n",
        "Our model is going to consist from\n",
        "* Feature extractor (a bunch of convolutions and downsamplings)\n",
        "* Decdoer (a bunch of upsamplings and convolutions, followed by a fully connected head for each pixel)\n",
        "\n",
        "This modular architecture is common: the feature extactor can be swapped to another one easily.\n",
        "\n",
        "For classification, only feature extractor part would be needed (with\n",
        "a fully connected head for the class predictions)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "nQt0hWTuE2va",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Code adapted from: https://keras.io/examples/vision/oxford_pets_image_segmentation/\n",
        "# (Apache 2.0 License: https://github.com/keras-team/keras-io/blob/master/LICENSE)\n",
        "\n",
        "\n",
        "# A simple encoder-decoder model for semantic segmentation.\n",
        "# More on residual networks: https://arxiv.org/abs/1512.03385.\n",
        "\n",
        "def get_model(img_size, num_classes):\n",
        "    inputs = keras.Input(shape=img_size + (3,))\n",
        "\n",
        "    ### === Feature extractor ====\n",
        "    # This can be separately trained with a classfication head for pre-training.\n",
        "\n",
        "    # Entry block\n",
        "    x = layers.Conv2D(32, 3, strides=2, padding=\"same\")(inputs)\n",
        "    x = layers.BatchNormalization()(x)\n",
        "    x = layers.Activation(\"relu\")(x)\n",
        "\n",
        "    previous_block_activation = x  # Set aside residual\n",
        "\n",
        "    # Blocks 1, 2, 3 are identical apart from the feature depth.\n",
        "    for filters in [64, 128, 256]:\n",
        "        x = layers.Activation(\"relu\")(x)\n",
        "        x = layers.SeparableConv2D(filters, 3, padding=\"same\")(x)\n",
        "        x = layers.BatchNormalization()(x)\n",
        "\n",
        "        x = layers.Activation(\"relu\")(x)\n",
        "        x = layers.SeparableConv2D(filters, 3, padding=\"same\")(x)\n",
        "        x = layers.BatchNormalization()(x)\n",
        "\n",
        "        # Downscaling\n",
        "        x = layers.MaxPooling2D(3, strides=2, padding=\"same\")(x)\n",
        "\n",
        "        # Project residual\n",
        "        residual = layers.Conv2D(filters, 1, strides=2, padding=\"same\")(\n",
        "            previous_block_activation\n",
        "        )\n",
        "        x = layers.add([x, residual])  # Add back residual\n",
        "        previous_block_activation = x  # Set aside next residual\n",
        "\n",
        "    ### === Segmentation decoder ====\n",
        "    # Takes features generated by the feature extractor and produces\n",
        "    # Segmentation outputs.\n",
        "\n",
        "    previous_block_activation = x  # Set aside residual\n",
        "\n",
        "    for filters in [256, 128, 64, 32]:\n",
        "        x = layers.Activation(\"relu\")(x)\n",
        "        x = layers.Conv2DTranspose(filters, 3, padding=\"same\")(x)\n",
        "        x = layers.BatchNormalization()(x)\n",
        "\n",
        "        x = layers.Activation(\"relu\")(x)\n",
        "        x = layers.Conv2DTranspose(filters, 3, padding=\"same\")(x)\n",
        "        x = layers.BatchNormalization()(x)\n",
        "\n",
        "        # Upsacling\n",
        "        x = layers.UpSampling2D(2)(x)\n",
        "\n",
        "        # Project residual\n",
        "        residual = layers.UpSampling2D(2)(previous_block_activation)\n",
        "        residual = layers.Conv2D(filters, 1, padding=\"same\")(residual)\n",
        "        x = layers.add([x, residual])  # Add back residual\n",
        "        previous_block_activation = x  # Set aside next residual\n",
        "\n",
        "    # Add a per-pixel classification layer to assign segmentation classes.\n",
        "    outputs = layers.Conv2D(num_classes, 3, activation=\"softmax\", padding=\"same\")(x)\n",
        "\n",
        "    # Define the model\n",
        "    model = keras.Model(inputs, outputs)\n",
        "    return model"
      ],
      "execution_count": 16,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "34O4bxrxE_0H",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "model = get_model( (448, 448), 2)\n",
        "model.summary()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "62i9bzLnc-uJ",
        "colab_type": "text"
      },
      "source": [
        "### Do some training!\n",
        "Now let's create a validation dataset and do some training on GPU."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Gq5DdTWWFUDz",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "val_ds = create_dataset(Spacenet(),\n",
        "                        split=VAL_SPLIT,\n",
        "                        shuffle=False,\n",
        "                        preprocess_fn=validation_preprocess_fn,\n",
        "                        batch_size=BATCH_SIZE)\n",
        "with tf.device('/device:GPU:0'):\n",
        "  model = get_model( (448, 448), 2)\n",
        "  model.compile(optimizer='rmsprop', loss=\"sparse_categorical_crossentropy\")\n",
        "  model.fit(train_ds, epochs=10, steps_per_epoch=200, validation_data=val_ds, validation_steps=4)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ooqEoKs5dNOD",
        "colab_type": "text"
      },
      "source": [
        "### Looking at the training performance\n",
        "\n",
        "Let's see the model predictions on a batch of training data.\n",
        "As we can see, it is still not perfect and shows some patterns in the\n",
        "probelms it suffers."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6YgMpCfJMVpW",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "for batch in train_ds.take(1):\n",
        "  predictions = model.predict(batch[0])\n",
        "  fig, axs = plt.subplots(nrows=BATCH_SIZE, ncols=4, figsize=(16, 4*BATCH_SIZE))\n",
        "  for i in range(BATCH_SIZE):\n",
        "    axs[i, 0].imshow(batch[0][i])\n",
        "    axs[i, 1].imshow(tf.squeeze(batch[1][i]))\n",
        "    axs[i, 2].imshow(tf.squeeze(predictions[i, :, :, 1] > 0.5))\n",
        "    axs[i, 3].imshow(tf.squeeze(predictions[i, :, :, 1]))\n",
        "  axs[0,0].set_title('Image')\n",
        "  axs[0,1].set_title('Ground truth')\n",
        "  axs[0,2].set_title('Segmentation @0.5')\n",
        "  axs[0,3].set_title('Segmentation score')\n",
        "\n",
        "\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CL1DopkXjH6c",
        "colab_type": "text"
      },
      "source": [
        "\n",
        "### Looking at the validation performance\n",
        "\n",
        "The validation performance shows us how good the model is in generalizing beyond\n",
        " the training set."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8Pp7xOV_ikQ3",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "for batch in val_ds.take(1):\n",
        "  predictions = model.predict(batch[0])\n",
        "  fig, axs = plt.subplots(nrows=BATCH_SIZE, ncols=4, figsize=(16, 4*BATCH_SIZE))\n",
        "  for i in range(BATCH_SIZE):\n",
        "    axs[i, 0].imshow(batch[0][i])\n",
        "    axs[i, 1].imshow(tf.squeeze(batch[1][i]))\n",
        "    axs[i, 2].imshow(tf.squeeze(predictions[i, :, :, 1] > 0.5))\n",
        "    axs[i, 3].imshow(tf.squeeze(predictions[i, :, :, 1]))\n",
        "  axs[0,0].set_title('Image')\n",
        "  axs[0,1].set_title('Ground truth')\n",
        "  axs[0,2].set_title('Segmentation @0.5')\n",
        "  axs[0,3].set_title('Segmentation score')"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}